import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
from scipy.io import loadmat, savemat
import numpy as np
import torch as t
from helper_functions import synthesize_results, get_circular_lineout, calc_pcfrc, calc_pcmse, calc_sqrt_fidelity
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator, LogLocator)

# Figures saved to PDF get a clear background: the figure and savefig face
# colors are fully transparent white, while the axes stay opaque white.
_clear = (1.0, 1.0, 1.0, 0.0)
_white = (1.0, 1.0, 1.0, 1.0)
plt.rcParams["figure.facecolor"] = _clear
plt.rcParams["axes.facecolor"] = _white
plt.rcParams["savefig.facecolor"] = _clear

# Toggle: write the synthesized (half/half/full) reconstructions to disk.
save_synthesized = True

# Toggle: write the relevant figures to disk (True) or only plot them (False).
save_figures = False
if save_figures:
    # The output directory is only needed when figures are actually saved.
    from pathlib import Path
    figure_dir = Path('figs')
    figure_dir.mkdir(exist_ok=True)

# The ptychography dataset; it used single shots for each exposure.
filename = 'single_shot_ptycho.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(
    'data/%s' % filename,
    cut_zeros=False,
)

# Conversion factors used to express the detector values as a signal level
# in photons.
# NOTE(review): 60 / 3.6 presumably is the photon energy (eV) divided by the
# energy per electron-hole pair in silicon (eV) -- confirm.
ADUperelectron = 2**16 / 200000  # from the MTE-2048B datasheet
electronsperphoton = 60 / 3.6
ADUperphoton = ADUperelectron * electronsperphoton

# A preliminary computational experiment compared the quality of
# super-resolution ptychography reconstructions at different padding levels.
# The approximate center of the Siemens star is defined here to enable that
# comparison, although the final analysis does not need it.
base_center = (906, 898)
base_radius = (1142 - 670) // 2 + 50

def make_window(padding):
    """Build the 2D slice which extracts a consistently sized FOV.

    ``base_center`` and ``base_radius`` are rescaled by the factor
    ``(1024 + 2 * padding) / 1024``. The center is scaled about the pixel
    offset 200, and every scaled value is truncated to an integer.

    Parameters
    ----------
    padding : int or float
        The padding, which enters through ``1024 + 2 * padding``.

    Returns
    -------
    window : tuple of slice
        One slice per image axis, each ``2 * radius`` pixels wide and
        centered on the rescaled center.
    """
    scale = (1024 + padding*2)/1024
    half_width = int(base_radius * scale)
    bounds = []
    for coord in base_center:
        mid = int((coord-200) * scale + 200)
        bounds.append(slice(mid-half_width, mid+half_width))
    return tuple(bounds)

# The padding was 200 pixels for the included ptychography data
padding = 200

# Load all three reconstructions (first half, second half, full). Their file
# names differ only in the tag that follows the dataset name, which is the
# filename with its last extension removed.
_stem = '.'.join(filename.split('.')[:-1])
half_1, half_2, full = [
    loadmat('results/%s_%s_polished_%dpad.mat' % (_stem, _tag, padding))
    for _tag in ('half_1', 'half_2', 'full')
]

# This accounts for a change in CDTools between when the reconstructions used
# for the paper were run, and the latest resubmission. Only the first half
# result is inspected; if it lacks the 'basis' key, the older key names are
# aliased to the newer ones in all three results.
if 'basis' not in half_1:
    for _recon in (half_1, half_2, full):
        _recon['basis'] = _recon['obj_basis']
        _recon['translation'] = _recon['translations']

# Build the common field-of-view window, then let the helper function shift
# all reconstructions to a common frame and calculate the key metrics.
window = make_window(padding)

# This compares the two 50% data reconstructions with the full data result,
# correcting for phase ramps, shifts, and an overall amplitude exponent which
# is left unconstrained by the introduction of a per-exposure weighting
# factor.
results = synthesize_results(half_1, half_2, full, window)

if save_synthesized:
    _synth_path = 'results/%s_synthesized_polished_%dpad.mat' % (
        '.'.join(filename.split('.')[:-1]), padding)
    savemat(_synth_path, results)





#
# Plotting Section
#

# Shorthands for some commonly used entries of the results
basis, freqs, ssnr, frc, threshold = (
    results[key]
    for key in ('basis', 'frc_freqs', 'ssnr', 'frc', 'frc_threshold'))

psize = np.abs(basis[0,1])
psize_A = int(np.round(psize*1e10))  # pixel size in angstroms, 3-4 sig figs
print('Pixel size in Angstrom:', psize_A)

# Bar chart of the percentage of the total probe power in each mode. The
# power of a mode is its squared magnitude summed over axes 1 and 2.
mode_power = np.sum(np.abs(results['probe_full'])**2, axis=(1,2))
intensities = mode_power * 100 / np.sum(mode_power)
plt.figure(figsize=(3.75,2.5))
plt.bar(np.arange(len(intensities)), intensities)
plt.xlabel('Mode Number')
plt.ylabel('Power Fraction (%)')
plt.tight_layout()


def savefig_plus_im(fname):
    """Saves the current figure, and the displayed im in full resolution.

    The figure gets saved as a pdf, but with imshow images, that pdf will
    have a sampled/rescaled version of the image. This function saves out
    a full resolution png at the same time, using the first image on the
    current axes together with that image's colormap and color limits, so
    that the png matches what is displayed.

    Parameters
    ----------
    fname : str
        The output path without an extension. The figure is written to
        ``fname + '.pdf'`` and the image to ``fname + '_<psize_A>Apix.png'``,
        where ``psize_A`` is the module-level pixel size.

    Raises
    ------
    IndexError
        If the current axes hold no image. The pdf has already been written
        at that point.
    """
    plt.savefig(fname + '.pdf')
    im = plt.gca().get_images()[0]
    # Without explicit limits imsave rescales to the data min/max, which
    # differs from the displayed image whenever the plot set its own limits.
    vmin, vmax = im.get_clim()
    plt.imsave(fname + '_%dApix.png' % psize_A, im.get_array(),
               vmin=vmin, vmax=vmax, cmap=im.cmap)


# First, the raw data. The per-pattern signal (each pattern summed over axes
# 1 and 2, converted from ADU to photons) is computed before the clamp below.
intensities = dataset.patterns.sum(dim=(1,2)).numpy() / ADUperphoton
mean_signal = intensities.mean()
std_signal = intensities.std()

print('Mean ptycho signal Level:', int(mean_signal), 'photons')
print('STD ptycho signal Level:', int(std_signal), 'photons')

# Negative detector values are set to zero, presumably so that the
# log10(x + 1) plots below stay well defined.
dataset.patterns = dataset.patterns.clamp(min=0)

# Only the first diffraction pattern is shown, once in ADU and once in photons
first_pattern = dataset.patterns[0].numpy()

p.plot_real(np.log10(first_pattern + 1))
plt.title('Log base 10 of ptycho detector data + 1 (ADU)')
if save_figures:
    savefig_plus_im('figs/ptycho_data_adu')

p.plot_real(np.log10(first_pattern / ADUperphoton + 1))
plt.title('Log base 10 of ptycho detector data + 1 (Photons)')
if save_figures:
    savefig_plus_im('figs/ptycho_data_photons')

# Second, we plot the ptychography images. Every object is cropped to the
# common field-of-view window before plotting.
a = p.plot_amplitude(results['obj_full'][window], basis=basis)
plt.title('Amplitude of the full ptychography result')
if save_figures:
    savefig_plus_im('figs/ptycho_obj_amplitude')

# The two titles below are raw strings: '\%' is not a valid escape sequence,
# so a plain string literal triggers a warning on recent Python versions.
# The text handed to matplotlib is unchanged.
# NOTE(review): the backslash is only needed when LaTeX text rendering is
# enabled -- confirm that it is intended.
p.plot_amplitude(results['obj_half_1'][window], basis=basis)
plt.title(r'Amplitude of the first 50\% ptychography result')
if save_figures:
    savefig_plus_im('figs/ptycho_half1_obj_amplitude')

p.plot_amplitude(results['obj_half_2'][window], basis=basis)
plt.title(r'Amplitude of the second 50\% ptychography result')
if save_figures:
    savefig_plus_im('figs/ptycho_half2_obj_amplitude')

p.plot_phase(results['obj_full'][window], basis=basis)
plt.title('Phase of the full ptychography result')
if save_figures:
    savefig_plus_im('figs/ptycho_obj_phase')

p.plot_colorized(results['obj_full'][window], basis=basis)
plt.title('Colorized Image of the full ptychography result')
if save_figures:
    savefig_plus_im('figs/ptycho_obj_colorized')

# Base 10 log of the amplitude of the far-field propagated object, computed
# as a natural log divided by log(10)
p.plot_real(t.log(t.abs(cdtools.tools.propagators.far_field(
    t.as_tensor(results['obj_full'][window]))))/np.log(10))
plt.title('Log Base 10 Amplitude of FFT of ptychography reconstruction')
if save_figures:
    savefig_plus_im('figs/ptycho_obj_fourier_amplitude')

# NOTE(review): the title itself flags the units as wrong, and this figure
# is displayed but never saved.
p.plot_amplitude(results['illumination_map_full'][window], basis=basis)
plt.title('Illumination Intensity (photons per pixel, wrong)')


probe_basis = results['original_probe_basis']
print('Probe basis:', probe_basis)

# Squared magnitude of the probe summed over axis 0 (the mode axis), shown
# without a title
p.plot_amplitude(np.sum(np.abs(results['probe_full'])**2, axis=0),
                 basis=probe_basis)
if save_figures:
    savefig_plus_im('figs/ptycho_probe_full_intensity')

# The first three probe modes: all amplitude plots first, then all colorized
# plots. Saved file names number the modes from 1.
mode_names = ('top', 'second', 'third')
for plot_mode, title_fmt, fname_fmt in (
        (p.plot_amplitude,
         'Amplitude of the full ptycho %s probe mode',
         'figs/ptycho_probe_full_amplitude_mode_%d'),
        (p.plot_colorized,
         'Colorized full ptycho %s probe mode',
         'figs/ptycho_probe_full_colorized_mode_%d')):
    for mode_idx, mode_name in enumerate(mode_names):
        plot_mode(results['probe_full'][mode_idx], basis=probe_basis)
        plt.title(title_fmt % mode_name)
        if save_figures:
            savefig_plus_im(fname_fmt % (mode_idx + 1))

# The frequencies are scaled by 1e-6 to match the cycles / um axis label
freqs_um = freqs*1e-6

# FRC curve (solid) together with its threshold (dashed)
plt.figure(figsize=(3.75,2.5))
frc_ax = plt.gca()
frc_ax.plot(freqs_um, frc, 'k-')
frc_ax.plot(freqs_um, threshold, 'k--')
frc_ax.set_xlabel('Spatial Frequency (cycles / um)')
frc_ax.set_ylabel('FRC')
frc_ax.grid(True)
plt.tight_layout()

# SSNR on a logarithmic y axis, with minor ticks at .2, .4, .6 and .8 of
# each decade and grid lines on both major and minor ticks
plt.figure(figsize=(3.75,2.5))
ssnr_ax = plt.gca()
ssnr_ax.semilogy(freqs_um, ssnr, 'k-')
minor_locator = ssnr_ax.yaxis.get_minor_locator()
minor_locator.set_params(numticks=99, subs=[.2, .4, .6, .8])
ssnr_ax.grid(True, which='both')
ssnr_ax.set_xlabel('Spatial Frequency (cycles / um)')
ssnr_ax.set_ylabel('SSNR')
plt.tight_layout()


# The probe modes from the second half reconstruction are registered to
# those of the first half, first directly and then after far-field
# propagation.
p1 = results['probe_half_1']
p2 = results['probe_half_2']

_imgproc = cdtools.tools.image_processing
_prop = cdtools.tools.propagators


def _summed_intensity(modes):
    """Returns the squared magnitude summed over axis 0, as a tensor."""
    return t.as_tensor(np.sum(np.abs(modes)**2, axis=0))


def _shift_modes(modes, shift):
    """Applies the same sinc subpixel shift to every mode in the stack."""
    return np.stack([
        _imgproc.sinc_subpixel_shift(t.as_tensor(mode), shift).numpy()
        for mode in modes])


# Step 1: find the shift between the summed intensities and apply it to p2
probe_shift = _imgproc.find_shift(_summed_intensity(p1),
                                  _summed_intensity(p2))
shifted_p2 = _shift_modes(p2, probe_shift)

# Step 2: repeat the same registration on the far-field propagated probes
fft_p1 = _prop.far_field(t.as_tensor(p1)).numpy()
fft_p2 = _prop.far_field(t.as_tensor(shifted_p2)).numpy()
probe_fft_shift = _imgproc.find_shift(_summed_intensity(fft_p1),
                                      _summed_intensity(fft_p2))
shifted_fft_p2 = _shift_modes(fft_p2, probe_fft_shift)

# Step 3: propagate back to obtain the fully registered second probe
final_p2 = _prop.inverse_far_field(t.as_tensor(shifted_fft_p2)).numpy()



# Compare the first half probe with the registered second half probe.
# NOTE: this rebinds `freqs`, which previously held the object FRC
# frequencies.
freqs, pcfrc = calc_pcfrc(t.as_tensor(p1), t.as_tensor(final_p2), 40)

# Divide the returned frequencies by the probe pixel size, presumably
# converting them from cycles per pixel to cycles per meter -- confirm.
freqs *= 1 / (np.abs(results['original_probe_basis'][0,1]))

pcmse = calc_pcmse(t.as_tensor(p1), t.as_tensor(final_p2))
print('PCMSE between probes:', pcmse.cpu().numpy())
sqrtfid = calc_sqrt_fidelity(t.as_tensor(p1), t.as_tensor(final_p2))
# NOTE(review): the normalization uses the total power of p2 rather than of
# final_p2; the two agree only if the shifts preserve total power -- confirm.
norm_pcmse = 1 - sqrtfid**2 / (np.sum(np.abs(p1)**2)*np.sum(np.abs(p2)**2))
# The label previously read 'PCSME', a typo for PCMSE
print('Normalized PCMSE between probes:', norm_pcmse.cpu().numpy())
print('sqrt fidelity between probes:', sqrtfid.cpu().numpy())


# PCFRC between the two half-data probes; the frequencies are scaled by 1e-6
# to match the cycles / um axis label
plt.figure(figsize=(3.75,2.5))
pcfrc_ax = plt.gca()
pcfrc_ax.plot(freqs*1e-6, pcfrc, 'k-')
pcfrc_ax.set_xlabel('Spatial Frequency (cycles / um)')
pcfrc_ax.set_ylabel('PCFRC')
pcfrc_ax.set_xlim([-0.25, 5])
pcfrc_ax.grid(True)
plt.tight_layout()
if save_figures:
    plt.savefig('figs/probe_pcfrc.pdf')

# Display every figure created above
plt.show()


